{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0",
   "metadata": {
    "lines_to_next_cell": 2
   },
   "source": [
    "# HuggingFace Chat Target - optional\n",
    "\n",
    "This notebook demonstrates **instruction models** that use a **chat template**, so you can experiment with structured, chat-based interactions. Non-instruct models are excluded to keep those interactions consistent and reliable. More instruct models can be explored on Hugging Face.\n",
    "\n",
    "## Key Points:\n",
    "\n",
    "1. **Supported Instruction Models**:\n",
    "   - The following **instruct models** follow a structured chat template and are supported by this notebook. They are examples only:\n",
    "     - `HuggingFaceTB/SmolLM-360M-Instruct`\n",
    "     - `microsoft/Phi-3-mini-4k-instruct`\n",
    "     - `...`\n",
    "\n",
    "2. **Excluded Models**:\n",
    "   - Non-instruct models (e.g., `google/gemma-2b`, `princeton-nlp/Sheared-LLaMA-1.3B-ShareGPT`) are **not included** in this demo, as they do not follow the structured chat template required by the current local Hugging Face model support.\n",
    "\n",
    "3. **Model Response Times**:\n",
    "   - The following average response times per prompt were measured on a CPU. Actual times vary with hardware and system load:\n",
    "     - `HuggingFaceTB/SmolLM-1.7B-Instruct`: 5.87 seconds\n",
    "     - `HuggingFaceTB/SmolLM-135M-Instruct`: 3.09 seconds\n",
    "     - `HuggingFaceTB/SmolLM-360M-Instruct`: 3.31 seconds\n",
    "     - `microsoft/Phi-3-mini-4k-instruct`: 4.89 seconds\n",
    "     - `Qwen/Qwen2-0.5B-Instruct`: 1.38 seconds\n",
    "     - `Qwen/Qwen2-1.5B-Instruct`: 2.96 seconds\n",
    "     - `stabilityai/stablelm-2-zephyr-1_6b`: 5.31 seconds\n",
    "     - `stabilityai/stablelm-zephyr-3b`: 8.37 seconds\n"
   ]
  },
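  {
   "cell_type": "markdown",
   "id": "chat-template-sketch",
   "metadata": {},
   "source": [
    "The chat template mentioned above turns a list of role-tagged messages into the single prompt string an instruct model was trained on. The block below is a minimal, illustrative sketch of that idea using a ChatML-style layout (the `<|im_start|>` / `<|im_end|>` markers used by Qwen2 instruct models). The function name `render_chatml` is hypothetical and is not part of PyRIT or `transformers`; in practice the authoritative template ships with each model's tokenizer and is applied with `tokenizer.apply_chat_template`.\n",
    "\n",
    "```python\n",
    "def render_chatml(messages, add_generation_prompt=True):\n",
    "    \"\"\"Render role/content dicts in a ChatML-style layout (illustrative only).\"\"\"\n",
    "    parts = []\n",
    "    for message in messages:\n",
    "        # Each turn is wrapped in start/end markers and tagged with its role.\n",
    "        parts.append(f\"<|im_start|>{message['role']}\\n{message['content']}<|im_end|>\\n\")\n",
    "    if add_generation_prompt:\n",
    "        # Leave an open assistant turn so the model knows it should answer next.\n",
    "        parts.append('<|im_start|>assistant\\n')\n",
    "    return ''.join(parts)\n",
    "\n",
    "\n",
    "messages = [{'role': 'user', 'content': 'What is 3*3? Give me the solution.'}]\n",
    "prompt = render_chatml(messages)\n",
    "print(prompt)\n",
    "```\n",
    "\n",
    "A model without such a template cannot be prompted this way, which is consistent with the exclusion of non-instruct models noted above."
   ]
  },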
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running model: Qwen/Qwen2-0.5B-Instruct\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average response time for Qwen/Qwen2-0.5B-Instruct: 106.50 seconds\n",
      "\n",
      "\n",
      "\u001b[34m────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n",
      "\u001b[1m\u001b[34m🔹 Turn 1 - USER\u001b[0m\n",
      "\u001b[34m────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n",
      "\u001b[34m  What is 3*3? Give me the solution.\u001b[0m\n",
      "\n",
      "\u001b[33m────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n",
      "\u001b[1m\u001b[33m🔸 ASSISTANT\u001b[0m\n",
      "\u001b[33m────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n",
      "\u001b[33m  3 * 3 = 9\u001b[0m\n",
      "\n",
      "\u001b[34m────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n",
      "\n",
      "\u001b[34m────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n",
      "\u001b[1m\u001b[34m🔹 Turn 1 - USER\u001b[0m\n",
      "\u001b[34m────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n",
      "\u001b[34m  What is 4*4? Give me the solution.\u001b[0m\n",
      "\n",
      "\u001b[33m────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n",
      "\u001b[1m\u001b[33m🔸 ASSISTANT\u001b[0m\n",
      "\u001b[33m────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n",
      "\u001b[33m  When you multiply two four-digit numbers together, the result will again be a four-digit number in\u001b[0m\n",
      "\u001b[33m      its own units place (20). Therefore,\u001b[0m\n",
      "\n",
      "\u001b[34m────────────────────────────────────────────────────────────────────────────────────────────────────\u001b[0m\n",
      "Qwen/Qwen2-0.5B-Instruct: 106.50 seconds\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "\n",
    "from pyrit.executor.attack import (\n",
    "    AttackExecutor,\n",
    "    ConsoleAttackResultPrinter,\n",
    "    PromptSendingAttack,\n",
    ")\n",
    "from pyrit.prompt_target import HuggingFaceChatTarget\n",
    "from pyrit.setup import IN_MEMORY, initialize_pyrit\n",
    "\n",
    "initialize_pyrit(memory_db_type=IN_MEMORY)\n",
    "\n",
    "# Model to test\n",
    "model_id = \"Qwen/Qwen2-0.5B-Instruct\"\n",
    "\n",
    "# List of prompts to send\n",
    "prompt_list = [\"What is 3*3? Give me the solution.\", \"What is 4*4? Give me the solution.\"]\n",
    "\n",
    "# Dictionary to store average response times\n",
    "model_times = {}\n",
    "\n",
    "print(f\"Running model: {model_id}\")\n",
    "\n",
    "try:\n",
    "    # Initialize HuggingFaceChatTarget with the current model\n",
    "    target = HuggingFaceChatTarget(model_id=model_id, use_cuda=False, tensor_format=\"pt\", max_new_tokens=30)\n",
    "\n",
    "    # Initialize the attack\n",
    "    attack = PromptSendingAttack(objective_target=target)\n",
    "\n",
    "    # Record start time (perf_counter is a monotonic clock suited to measuring intervals)\n",
    "    start_time = time.perf_counter()\n",
    "\n",
    "    # Send prompts asynchronously\n",
    "    responses = await AttackExecutor().execute_multi_objective_attack_async(  # type: ignore\n",
    "        attack=attack,\n",
    "        objectives=prompt_list,\n",
    "    )\n",
    "\n",
    "    # Record end time\n",
    "    end_time = time.perf_counter()\n",
    "\n",
    "    # Calculate total and average response time\n",
    "    total_time = end_time - start_time\n",
    "    avg_time = total_time / len(prompt_list)\n",
    "    model_times[model_id] = avg_time\n",
    "\n",
    "    print(f\"Average response time for {model_id}: {avg_time:.2f} seconds\\n\")\n",
    "\n",
    "    # Print the conversations\n",
    "    for result in responses:\n",
    "        await ConsoleAttackResultPrinter().print_conversation_async(result=result)  # type: ignore\n",
    "\n",
    "except Exception as e:\n",
    "    print(f\"An error occurred with model {model_id}: {e}\\n\")\n",
    "    model_times[model_id] = None\n",
    "\n",
    "# Print the model average time\n",
    "if model_times[model_id] is not None:\n",
    "    print(f\"{model_id}: {model_times[model_id]:.2f} seconds\")\n",
    "else:\n",
    "    print(f\"{model_id}: Error occurred, no average time calculated.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyrit.memory import CentralMemory\n",
    "\n",
    "memory = CentralMemory.get_memory_instance()\n",
    "memory.dispose_engine()"
   ]
  }
 ],
 "metadata": {
  "jupytext": {
   "cell_metadata_filter": "-all"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
